#include <chrono>
#include <iostream>
#include <random>

#include <CL/sycl.hpp>

using namespace std;
using namespace sycl;

// Order of the square system (rows == columns).
constexpr int N{1024};

// The N x N matrix, stored row-major as N*N contiguous floats. It is filled by
// reset() and wrapped in a SYCL buffer by main(). Namespace-scope storage, so
// it starts out zero-initialized.
float A[N][N]{};

// Rebuilds the global N x N matrix A as a deterministic test system for
// Gaussian elimination.
//
// First a unit upper-triangular matrix is written: 1 on the diagonal,
// pseudo-random integers from a fixed-seed mt19937 above it, and 0 below it.
// Then every row k is added into all rows below it, which fills the whole
// matrix. Every entry of A is written before it is read, so the result is the
// same no matter what A held before the call.
//
// NOTE(review): row k is added only after it has itself accumulated all earlier
// rows, so magnitudes roughly double from one row to the next. Starting from
// values up to INT_MAX, this exceeds float range (~3.4e38) after roughly a
// hundred rows, so for N = 1024 the lower rows become inf. Confirm whether that
// is acceptable for this benchmark.
void reset()
{
    // Fixed seed so every run produces the same matrix.
    const unsigned int seed = 2;
    std::mt19937 gen(seed);
    // Default parameters: integers in [0, INT_MAX].
    std::uniform_int_distribution<> dis;

    for (int i = 0; i < N; i++)
    {
        // Clear the strict lower triangle. The accumulation below reads these
        // entries, so stale values from an earlier call must not survive.
        for (int j = 0; j < i; j++)
        {
            A[i][j] = 0.0f;
        }
        A[i][i] = 1.0f;
        // Upper triangle, drawn in row-major order.
        for (int j = i + 1; j < N; j++)
        {
            A[i][j] = static_cast<float>(dis(gen));
        }
    }

    // Add each row k into every row below it.
    for (int k = 0; k < N; k++)
    {
        for (int i = k + 1; i < N; i++)
        {
            for (int j = 0; j < N; j++)
            {
                A[i][j] += A[k][j];
            }
        }
    }
}

// Forward phase of Gaussian elimination (no pivoting), run on the device behind
// `q`, in place on the matrix held in `buf`.
//
// On return the matrix is upper triangular with a unit diagonal.
//
// buf: the matrix. Only get_range()[0] is read and used for both dimensions.
//      NOTE(review): callers are assumed to pass a square n x n buffer.
// q:   queue the kernels are submitted to. The function blocks on q.wait()
//      before returning.
//
// No row swaps are performed, so a zero (or tiny) diagonal element yields
// inf/NaN.
void gauss_oneapi(buffer<float, 2> &buf, queue &q)
{
    const int n = static_cast<int>(buf.get_range()[0]);
    for (int k = 0; k < n; k++)
    {
        // Step 1: divide the part of row k to the right of the diagonal by the
        // pivot m[k][k]. The pivot itself is deliberately not written here.
        // If one work-item overwrote m[k][k] with 1 while others were still
        // reading it as the divisor, results would depend on scheduling.
        q.submit([&](handler &h) {
            accessor m{ buf, h, read_write };
            h.parallel_for(range(n - (k + 1)), [=](auto idx) {
                int j = k + 1 + idx;
                m[k][j] = m[k][j] / m[k][k];
            });
        });

        // Step 2: eliminate column k from rows k+1..n-1. Only columns > k are
        // updated, so m[i][k] stays intact as the row multiplier.
        q.submit([&](handler &h) {
            accessor m{ buf, h, read_write };
            h.parallel_for(range(n - (k + 1), n - (k + 1)), [=](auto idx) {
                int i = k + 1 + idx.get_id(0);
                int j = k + 1 + idx.get_id(1);
                m[i][j] = m[i][j] - m[i][k] * m[k][j];
            });
        });

        // Step 3: finish column k. Set the pivot to 1 (deferred from step 1)
        // and zero every entry below it.
        q.submit([&](handler &h) {
            accessor m{ buf, h, read_write };
            h.parallel_for(range(n - k), [=](auto idx) {
                int i = k + idx;
                m[i][k] = (i == k) ? 1.0f : 0.0f;
            });
        });
    }
    q.wait();
}

int main() {


    default_selector selector;
    queue q(selector);
    reset();

    // 创建 SYCL buffers
    buffer<float, 2> A_buffer(reinterpret_cast<float*>(A), range<2>(N, N));
    // Start timer
    auto start = chrono::high_resolution_clock::now();

    gauss_oneapi(A_buffer, q);
        
    // End timer
    auto end = chrono::high_resolution_clock::now();
    chrono::duration<double> duration = end - start;
    cout << "Time: " << duration.count() << "s" << std::endl;

    // 更新buffer到host
    A_buffer.get_access<access::mode::read>();

    return 0;
}
